{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6d8683b7-0fab-4937-b7ad-72d70a0260ac",
   "metadata": {},
   "source": [
    "# tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d33c6c4e-9fd7-4850-8800-12ac35a867a0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: wonderwords in /usr/local/anaconda3/envs/ghostaienv/lib/python3.8/site-packages (2.2.0)\n"
     ]
    }
   ],
   "source": [
    "# %pip installs into the environment of the running kernel; the version is pinned for reproducibility.\n",
    "%pip install wonderwords==2.2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "591305e5-0459-4f0b-9968-f77d207d0172",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os, json\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3cca5b71-75c5-44bc-9e80-330c93915f3d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
     ]
    }
   ],
   "source": [
    "from transformers import LlamaTokenizer\n",
    "import torch\n",
    "\n",
    "# Hugging Face model id; its tokenizer is used below to count the tokens in each prompt.\n",
    "base_model = \"huggyllama/llama-13b\"\n",
    "tokenizer = LlamaTokenizer.from_pretrained(base_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "041ff3ce-d593-4f7d-be0e-c5488aeb9156",
   "metadata": {},
   "source": [
    "# Generate topic eval dataset (random seed setup)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "97a48cc7-7c41-4e7d-96ca-4771472a3e81",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "\n",
    "np.random.seed(42) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1409125a-2030-4067-9a33-8612c4cd668b",
   "metadata": {},
   "source": [
    "# Generate lines eval dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4bf7c52c-ce66-483e-9a6a-c6067d1dbdeb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "def generate_line_index(num_line, idx_opt):\n",
    "    if idx_opt == \"LRT-ABCindex\":\n",
    "        ingredients = [\"A\", \"B\", \"C\", \"D\", \"E\", \"F\"]\n",
    "\n",
    "        start = 6\n",
    "        comb = list(itertools.product(ingredients, repeat=start))\n",
    "        while len(comb) < num_line:\n",
    "            start += 1\n",
    "            comb = list(itertools.product(ingredients, repeat=start))\n",
    "        \n",
    "        comb = [\"\".join(i) for i in comb]\n",
    "\n",
    "        return comb[:num_line]\n",
    "    elif idx_opt == \"LRT-UUID\":\n",
    "        comb = []\n",
    "        for i in range(num_line):\n",
    "            comb.append(str(uuid.uuid4()))\n",
    "        \n",
    "        return comb\n",
    "    elif idx_opt == \"LRT-NL\":\n",
    "        import wonderwords\n",
    "\n",
    "        w = wonderwords.RandomWord()\n",
    "        adjs = w.random_words(num_line, include_categories=[\"noun\"])\n",
    "        nouns = w.random_words(num_line, include_categories=[\"noun\"])\n",
    "\n",
    "        comb = []\n",
    "        for i, (adj, noun) in enumerate(zip(adjs, nouns)):\n",
    "            comb.append(f\"{adj}-{noun}\")\n",
    "        \n",
    "        return comb\n",
    "    \n",
    "def retrieve_expected(lines, random_line_pos):\n",
    "    correct_line = lines[random_line_pos]\n",
    "    expected_number = re.search(\"<\\d+>\", correct_line)\n",
    "    if expected_number is not None:\n",
    "        expected_number = int(expected_number.group()[1:-1])\n",
    "    else:\n",
    "        print(f\"Got unparsable line: {correct_line}\")\n",
    "\n",
    "    return expected_number, correct_line\n",
    "\n",
    "def generate_prompt_from_lines(lines):\n",
    "    prompt = \"\"\n",
    "    for l in lines:\n",
    "        prompt += l\n",
    "    \n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bbfe0587-b2ad-4334-88a7-5f8a62b17f30",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                                                                                | 0/20 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (92263 > 2048). Running this sequence through the model will result in indexing errors\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.18it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "3687.8080000000004"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random, re\n",
    "\n",
    "RECORD_COUNT = 20\n",
    "\n",
    "ROWS = [4000]\n",
    "output_dir = \".\"\n",
    "\n",
    "# Line-index style: \"LRT\" numbers lines 1..n; any other value is passed to generate_line_index.\n",
    "line_idx_opt = \"LRT-NL\"\n",
    "\n",
    "# Instruction text placed in front of the record lines; identical for every prompt, so built once.\n",
    "prompt_header = (\n",
    "    \"Below is a record of lines I want you to remember. \"\n",
    "    \"Each line begins with 'line <line index>' and contains \"\n",
    "    \"a '<REGISTER_CONTENT>' at the end of the line as a numerical value. \"\n",
    "    \"For each line index, memorize its corresponding <REGISTER_CONTENT>. At \"\n",
    "    \"the end of the record, I will ask you to retrieve the corresponding \"\n",
    "    \"<REGISTER_CONTENT> of a certain line index. Now the record start:\\n\\n\"\n",
    ")\n",
    "\n",
    "prompt_lens = []  # token count of every generated prompt\n",
    "\n",
    "for n in ROWS:\n",
    "    output_path = os.path.join(output_dir, f\"{n}_lines_en.jsonl\")\n",
    "\n",
    "    # `with` closes the file even if generation fails part-way through.\n",
    "    with open(output_path, \"w\", encoding=\"utf-8\") as f:\n",
    "        for record_no in tqdm(range(RECORD_COUNT)):\n",
    "            if line_idx_opt == \"LRT\":\n",
    "                line_idxes = list(range(1, n + 1))\n",
    "            else:\n",
    "                line_idxes = generate_line_index(n, line_idx_opt)\n",
    "\n",
    "            lines = [f\"line {i}: REGISTER_CONTENT is <{random.randint(1, 50000)}>\\n\" for i in line_idxes]\n",
    "\n",
    "            # random_num is the 0-based position in `lines`; random_idx is the label shown in the prompt.\n",
    "            random_num = random.randint(0, len(line_idxes) - 1)\n",
    "            random_idx = line_idxes[random_num]\n",
    "\n",
    "            # Parse the expected answer before the header is inserted at position 0.\n",
    "            expected_number, correct_line = retrieve_expected(lines, random_num)\n",
    "            lines.insert(0, prompt_header)\n",
    "            prompt_postfix = f\"\\nNow the record is over. Tell me what is the <REGISTER_CONTENT> in line {random_idx}? I need the number.\"\n",
    "\n",
    "            prompt = generate_prompt_from_lines(lines)\n",
    "            prompt_len = len(tokenizer.encode(prompt))\n",
    "            prompt_lens.append(prompt_len)\n",
    "\n",
    "            output = {\n",
    "                \"random_idx\": (random_idx, random_num),  # this is the line to retrieve\n",
    "                \"expected_number\": expected_number,\n",
    "                \"num_lines\": n,\n",
    "                \"prompt_len\": prompt_len,\n",
    "                \"correct_line\": correct_line,\n",
    "                \"prompt_postfix\": prompt_postfix,\n",
    "                \"prompt\": prompt,\n",
    "            }\n",
    "\n",
    "            # One JSON object per line (JSONL).\n",
    "            json.dump(output, f, ensure_ascii=False)\n",
    "            f.write(\"\\n\")\n",
    "\n",
    "# Mean prompt length in tokens, averaged over the records actually generated\n",
    "# rather than a fixed divisor.\n",
    "avg_len = sum(prompt_lens) / len(prompt_lens) if prompt_lens else 0.0\n",
    "avg_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3b3d86e-6da6-44f2-887a-ae1374961fa0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Show the first record of the last file written above; output_path already includes output_dir.\n",
    "!head -n 1 {output_path}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cf86aebf-b78d-43bf-8aea-d9ced7676855",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      20 4000_lines_en.jsonl\n"
     ]
    }
   ],
   "source": [
    "# Count the records in the last file written above; the count should equal RECORD_COUNT.\n",
    "!wc -l {output_path}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dbb8bc4-8449-43d3-b32b-fe0072a815e7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
